import torch
import torch.nn as nn
import torch.nn.functional as F


def truncated_normal_(tensor, mean: float = 0., std: float = 1.):
    """Fill ``tensor`` in place with truncated-normal samples.

    Samples are drawn from a standard normal restricted to ``[-2, 2]``, then
    multiplied by ``std`` and shifted by ``mean``, so every element lies
    within two ``std`` of ``mean``.

    The sampling is delegated to ``torch.nn.init.trunc_normal_`` so that the
    bound holds for every element.  A "draw a few candidates and keep the
    first valid one" scheme cannot guarantee that: whenever all candidates
    for an element fall outside the interval, an out-of-range value is kept.

    Args:
        tensor: Tensor (or ``nn.Parameter``) to overwrite in place.
        mean: Value added to every sample after scaling.
        std: Factor applied to the truncated standard-normal samples.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    with torch.no_grad():
        torch.nn.init.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)
        # Scale and shift after sampling (instead of passing mean/std to
        # trunc_normal_) so that std == 0 and negative std stay well defined.
        tensor.mul_(std).add_(mean)
    return tensor




class ActivationLayer(torch.nn.Module):
    """Base class for the dense layers defined in this file.

    Registers two parameters:

    * ``weight`` with shape ``(in_features, out_features)``
    * ``bias`` with shape ``(out_features,)``

    Both are allocated with ``torch.empty`` and therefore hold arbitrary
    values until a subclass initialises them.  ``forward`` must be
    overridden by subclasses.
    """

    def __init__(self,
                in_features: int,
                out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty((in_features, out_features)))
        self.bias = torch.nn.Parameter(torch.empty(out_features))

    def forward(self, x):
        """Abstract; always raises ``NotImplementedError``."""
        raise NotImplementedError("abstract method called")


class ReLUUnit(ActivationLayer):
    """Dense layer followed by ReLU: ``relu(x @ weight + bias)``.

    Args:
        in_features: Number of input columns (rows of ``weight``).
        out_features: Number of output columns.
        activation: Label stored on the instance as ``self.activation``;
            ``forward`` does not consult it.
    """

    def __init__(self,
                in_features: int,
                out_features: int,
                activation: str = 'relu'):
        super().__init__(in_features, out_features)
        if in_features > 0 and out_features > 0:
            torch.nn.init.xavier_uniform_(self.weight)
            truncated_normal_(self.bias, std=0.5)
        else:
            # The weight has no elements here, but the bias may (when
            # in_features == 0 and out_features > 0).  Give it a defined
            # value instead of leaving torch.empty garbage that forward()
            # would add to the output.  Zero-filling draws no random
            # numbers, so the initialisation of other layers is unaffected.
            torch.nn.init.zeros_(self.bias)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        """Return ``relu(x @ weight + bias)``."""
        out = x @ self.weight + self.bias
        return F.relu(out)


class LinearUnit(ActivationLayer):
    """Plain affine layer: ``x @ weight + bias`` with no non-linearity.

    Args:
        in_features: Number of input columns (rows of ``weight``).
        out_features: Number of output columns.
        activation: Label stored on the instance as ``self.activation``;
            ``forward`` does not consult it.
    """

    def __init__(self,
                in_features: int,
                out_features: int,
                activation: str = 'relu'):
        super().__init__(in_features, out_features)
        if in_features > 0 and out_features > 0:
            torch.nn.init.xavier_uniform_(self.weight)
            truncated_normal_(self.bias, std=0.5)
        else:
            # The weight has no elements here, but the bias may (when
            # in_features == 0 and out_features > 0).  Give it a defined
            # value instead of leaving torch.empty garbage that forward()
            # would add to the output.  Zero-filling draws no random
            # numbers, so the initialisation of other layers is unaffected.
            torch.nn.init.zeros_(self.bias)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        """Return ``x @ weight + bias``."""
        return x @ self.weight + self.bias


class ExpUnit_non_bias(ActivationLayer):
    """Bias-free layer that multiplies its input by ``exp(weight)``.

    Because the exponential is applied element-wise to the stored weight,
    the matrix actually used in the product is strictly positive.  The
    weight is initialised uniformly in ``[-5, 1]``.  The inherited ``bias``
    parameter is initialised as well but is not used by ``forward``.

    Args:
        in_features: Number of input columns (rows of ``weight``).
        out_features: Number of output columns.
        activation: Label stored on the instance as ``self.activation``;
            ``forward`` does not consult it.
    """

    def __init__(self,
                in_features: int,
                out_features: int,
                activation: str = 'relu_n'):
        super().__init__(in_features, out_features)
        self.size, self.activation = in_features, activation
        # Weight is drawn before bias; keep this order so seeded runs
        # consume the random stream identically.
        torch.nn.init.uniform_(self.weight, -5.0, 1.0)
        truncated_normal_(self.bias, std=0.5)

    def forward(self, x):
        """Return ``x @ exp(weight)``."""
        positive_weight = self.weight.exp()
        return x @ positive_weight

class ClipUnit_non_bias(ActivationLayer):
    """Bias-free layer that multiplies its input by ``clamp(weight, min=0)``.

    Negative entries of the stored weight are replaced by zero on every
    forward pass, so the matrix used in the product is non-negative.  The
    inherited ``bias`` parameter is initialised but is not used by
    ``forward``.

    Args:
        in_features: Number of input columns (rows of ``weight``).
        out_features: Number of output columns.
        activation: Label stored on the instance as ``self.activation``;
            ``forward`` does not consult it.
    """

    def __init__(self,
                in_features: int,
                out_features: int,
                activation: str = 'relu_n'):
        super().__init__(in_features, out_features)
        # xavier_uniform_ divides by fan_in + fan_out, which is zero for a
        # (0, 0) weight; an empty weight needs no initialisation anyway.
        if self.weight.numel() > 0:
            torch.nn.init.xavier_uniform_(self.weight)
        truncated_normal_(self.bias, std=0.5)
        # Expose ``size`` like the other layer classes in this file do.
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        """Return ``x @ clamp(weight, min=0)``."""
        weight_pos = torch.clamp(self.weight, min=0.0)
        return x @ weight_pos

class LinearUnit_non_bias(ActivationLayer):
    """Bias-free linear layer: ``x @ weight``.

    The inherited ``bias`` parameter is still registered (and initialised)
    but is not used by ``forward``.

    Args:
        in_features: Number of input columns (rows of ``weight``).
        out_features: Number of output columns.
        activation: Label stored on the instance as ``self.activation``;
            ``forward`` does not consult it.
    """

    def __init__(self,
                in_features: int,
                out_features: int,
                activation: str = 'relu'):
        super().__init__(in_features, out_features)
        if in_features > 0 and out_features > 0:
            torch.nn.init.xavier_uniform_(self.weight)
            truncated_normal_(self.bias, std=0.5)
        else:
            # The bias is unused by forward() but is still part of
            # parameters()/state_dict(); zero it rather than leaving
            # torch.empty garbage (possibly NaN/inf) behind.  Zero-filling
            # draws no random numbers, so other layers initialise as before.
            torch.nn.init.zeros_(self.bias)
        self.size = in_features
        self.activation = activation

    def forward(self, x):
        """Return ``x @ weight``."""
        return x @ self.weight

    
    



class SCNN(nn.Module):
    """Layered network that treats three groups of input columns differently.

    The input columns are partitioned into

    * ``x_c_features`` -- routed through bias-free linear layers
      (``xc_outer``),
    * ``x_s_features`` -- routed through ``exp_unit_non_bias`` layers
      (``xs_outer``); with the default ``ClipUnit_non_bias`` these use
      non-negative weights,
    * ``x_u_features`` -- every remaining column index in
      ``range(input_size)``; these feed an ordinary "u" path.

    With ``u_0 = x_u`` and ``u_{i+1} = relu(u_layers[i](u_i))``, layer ``i``
    computes::

        z_pre_i = skip[i](u_i)
                + xc_outer[i](x_c * xc_layers[i](u_i))
                + xs_outer[i](x_s * xs_layers[i](u_i))
                + z_main_outer[i](z_{i-1} * z_main[i](u_i))    # only for i > 0

    Terms for an empty feature group are omitted.  When there are no
    ``x_u`` columns at all, the ``skip`` term and the multiplicative gates
    are dropped, i.e. ``xc_outer[i](x_c)``, ``xs_outer[i](x_s)`` and
    ``z_main_outer[i](z_{i-1})`` are summed directly.

    ``z_i = g_act(z_pre_i)`` for every layer except the last one, whose
    pre-activation is returned as is.  A single-layer network is the
    exception: its only layer is still passed through ``g_act``.

    NOTE(review): the index lists are expected to be disjoint, free of
    duplicates and within ``range(input_size)``; otherwise ``x_u_size``
    (computed by subtraction) disagrees with ``len(x_u_features)``.

    Args:
        input_size: Total number of input columns.
        x_c_features: Column indices of the ``x_c`` group.
        x_s_features: Column indices of the ``x_s`` group.
        u_dims: Output width of each u-layer; its length sets the number of
            layers.
        z_dims: Output width of each z-layer; needs at least
            ``len(u_dims)`` entries (extra entries are ignored).
        linear_unit: Layer class called as ``cls(in_features, out_features)``
            for ``u_layers``, ``skip`` and ``xc_layers``.
        relu_unit: Layer class used for ``xs_layers`` and ``z_main``.
        linear_unit_non_bias: Layer class used for ``xc_outer``.
        exp_unit_non_bias: Layer class used for ``xs_outer`` and
            ``z_main_outer``.
        is_convex: Selects ``g_act``: ``relu(x)`` when true, otherwise
            ``-relu(-x)``.
    """

    def __init__(self,
                input_size:int,
                x_c_features: list,
                x_s_features: list,
                u_dims: tuple=(),
                z_dims: tuple=(),
                linear_unit: ActivationLayer = LinearUnit,  # x @ W + b
                relu_unit: ActivationLayer = ReLUUnit,  # relu(x @ W + b)
                linear_unit_non_bias: ActivationLayer = LinearUnit_non_bias,  # x @ W
                exp_unit_non_bias: ActivationLayer = ClipUnit_non_bias,  # x @ clamp(W, min=0)
                is_convex: bool = True):
        super().__init__()

        self.input_size = input_size
        self.x_c_features = x_c_features
        self.x_s_features = x_s_features
        # Sorted so the column order does not depend on set iteration order,
        # which the language does not guarantee.
        self.x_u_features = sorted(
            set(range(input_size)).difference(x_c_features, x_s_features)
        )

        self.x_c_size = len(x_c_features)
        self.x_s_size = len(x_s_features)
        self.x_u_size = input_size - self.x_c_size - self.x_s_size
        self.u_dims = u_dims
        self.z_dims = z_dims
        self.is_convex = is_convex
        self.n_layers = len(u_dims)

        # activation
        self.g_act = F.relu if is_convex else (lambda x: -F.relu(-x))
        self.h_act = F.relu

        layers = range(self.n_layers)
        # Width of the u-path input at each layer: the raw x_u columns for
        # layer 0, the previous u-layer's output afterwards.
        u_in = [self.x_u_size if i == 0 else u_dims[i - 1] for i in layers]

        # The ModuleLists are created in this fixed order so that parameter
        # names, state_dict layout and seeded initialisation stay stable.
        self.u_layers = torch.nn.ModuleList(
            [linear_unit(u_in[i], u_dims[i]) for i in layers]
        )

        self.skip = torch.nn.ModuleList(
            [linear_unit(u_in[i], z_dims[i]) for i in layers]
        )

        self.xc_layers = torch.nn.ModuleList(
            [linear_unit(u_in[i], self.x_c_size) for i in layers]
        )
        self.xc_outer = torch.nn.ModuleList(
            [linear_unit_non_bias(self.x_c_size, z_dims[i]) for i in layers]
        )

        self.xs_layers = torch.nn.ModuleList(
            [relu_unit(u_in[i], self.x_s_size) for i in layers]
        )
        self.xs_outer = torch.nn.ModuleList(
            [exp_unit_non_bias(self.x_s_size, z_dims[i]) for i in layers]
        )

        # Entry 0 of z_main and z_main_outer is never used by forward()
        # (layer 0 has no previous z); it is kept so that existing
        # checkpoints keep loading.
        self.z_main = torch.nn.ModuleList(
            [relu_unit(u_in[i], z_dims[0] if i == 0 else z_dims[i - 1]) for i in layers]
        )
        self.z_main_outer = torch.nn.ModuleList(
            [exp_unit_non_bias(self.x_u_size if i == 0 else z_dims[i - 1], z_dims[i])
             for i in layers]
        )

    def _pre_activation(self, i, u, z_prev, x_c, x_s):
        """Return the pre-activation of z-layer ``i``.

        ``u`` is the u-path input of this layer (ignored when there are no
        ``x_u`` columns) and ``z_prev`` is the previous z-layer output
        (ignored for ``i == 0``).
        """
        has_u = self.x_u_size != 0

        z_pre = self.skip[i](u) if has_u else 0
        if self.x_c_size != 0:
            gated_c = (x_c * self.xc_layers[i](u)) if has_u else x_c
            z_pre = z_pre + self.xc_outer[i](gated_c)
        if self.x_s_size != 0:
            gated_s = (x_s * self.xs_layers[i](u)) if has_u else x_s
            z_pre = z_pre + self.xs_outer[i](gated_s)
        if i > 0:
            gated_z = (z_prev * self.z_main[i](u)) if has_u else z_prev
            z_pre = z_pre + self.z_main_outer[i](gated_z)
        return z_pre

    def forward(self, x):
        """Run the network on ``x`` and return the last z-layer output.

        Args:
            x: Tensor indexed as ``x[:, columns]``, i.e. with the feature
                columns on the second axis (presumably
                ``(batch, input_size)`` -- confirm against callers).

        Returns:
            Output of the final z-layer, with ``z_dims[n_layers - 1]``
            columns.

        Raises:
            RuntimeError: If the network was built with no layers.
        """
        if self.n_layers == 0:
            raise RuntimeError(
                "SCNN has no layers: u_dims must contain at least one entry"
            )

        x_c = x[:, self.x_c_features]
        x_s = x[:, self.x_s_features]
        x_u = x[:, self.x_u_features]

        has_u = self.x_u_size != 0
        last = self.n_layers - 1

        u = x_u
        z = None
        for i in range(self.n_layers):
            z_pre = self._pre_activation(i, u, z, x_c, x_s)
            # The final layer is left un-activated, except that layer 0 is
            # always activated (even when it is also the final layer).
            z = z_pre if (i == last and i > 0) else self.g_act(z_pre)
            # Advance the u-path only if a later layer will consume it.
            if has_u and i < last:
                u = self.h_act(self.u_layers[i](u))

        return z


        
        

if __name__ == "__main__":
    # Smoke test: build a small two-layer model, list its parameters and
    # push one random batch through it.
    batch_size = 7
    input_size = 6

    model = SCNN(input_size=input_size,
                x_c_features=[2],
                x_s_features=[0, 1],
                u_dims=(5, 1),
                z_dims=(11, 1),
                is_convex=False)

    # number of model parameters
    param_amount = 0
    for name, param in model.named_parameters():
        print(name, param.numel())
        param_amount += param.numel()
    print('total param amount:', param_amount)

    # forward() takes one tensor whose columns are indexed by the feature lists
    out = model(torch.randn(batch_size, input_size))
    print("Output shape:", out.shape)